{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only GPU index 2 to CUDA; set in this first cell, ahead of the TensorFlow import below.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the dataset JSON from the working directory (presumably BPE token ids, going by the file name).\n",
    "with open('dataset-bpe.json') as dataset_file:\n",
    "    data = json.load(dataset_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the four splits out of the loaded dict: sources (X) and targets (Y) for train and test.\n",
    "train_X, train_Y = data['train_X'], data['train_Y']\n",
    "test_X, test_Y = data['test_X'], data['test_Y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "EOS = 2\n",
    "GO = 1\n",
    "vocab_size = 32000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Terminate every target sequence with EOS (previously the literal 2 was repeated here).\n",
    "# The check keeps the cell safe to re-run: a sequence that already ends in EOS is left\n",
    "# unchanged instead of receiving a second terminator.\n",
    "train_Y = [seq if seq[-1:] == [EOS] else seq + [EOS] for seq in train_Y]\n",
    "test_Y = [seq if seq[-1:] == [EOS] else seq + [EOS] for seq in test_Y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import beam_search\n",
    "\n",
    "def pad_second_dim(x, desired_size):\n",
    "    \"\"\"Zero-pad a rank-3 tensor along axis 1 up to `desired_size`.\n",
    "\n",
    "    Args:\n",
    "        x: rank-3 tensor; it is concatenated with float32 zeros, so it is expected to be float32.\n",
    "        desired_size: target length of axis 1 (expected to be >= the current length).\n",
    "\n",
    "    Returns:\n",
    "        `x` with a block of zeros appended along axis 1.\n",
    "    \"\"\"\n",
    "    dims = tf.shape(x)\n",
    "    pad_shape = tf.stack([dims[0], desired_size - dims[1], dims[2]], 0)\n",
    "    zeros = tf.tile([[[0.0]]], pad_shape)\n",
    "    return tf.concat([x, zeros], 1)\n",
    "\n",
    "class Translator:\n",
    "    \"\"\"GRU encoder-decoder (seq2seq) graph over token ids, built in TF1 graph mode.\n",
    "\n",
    "    Token id 0 is treated as padding: sequence lengths are the per-row count of\n",
    "    non-zero ids. One embedding matrix is shared by the encoder inputs, the\n",
    "    teacher-forced decoder inputs and the greedy decoder. The constructor reads\n",
    "    the module-level names `vocab_size`, `GO` and `EOS`.\n",
    "\n",
    "    Attributes:\n",
    "        X, Y: int32 placeholders of shape [batch, time] for source / target ids.\n",
    "        X_seq_len, Y_seq_len: per-row non-zero counts of X / Y.\n",
    "        training_logits: decoder logits under teacher forcing.\n",
    "        fast_result: greedy-decoded ids, at most 2 * (longest source length) steps.\n",
    "        cost: sequence loss weighted by the target mask.\n",
    "        optimizer: Adam op minimising `cost`.\n",
    "        prediction: argmax ids at the non-padded target positions.\n",
    "        accuracy: mean token accuracy over the non-padded target positions.\n",
    "    \"\"\"\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, learning_rate):\n",
    "        \"\"\"Assemble the training and inference graph in the current default graph.\n",
    "\n",
    "        Args:\n",
    "            size_layer: number of units in each GRU cell.\n",
    "            num_layers: number of stacked GRU cells in the encoder and in the decoder.\n",
    "            embedded_size: width of the shared token embedding.\n",
    "            learning_rate: learning rate passed to Adam.\n",
    "        \"\"\"\n",
    "\n",
    "        def cells(reuse=False):\n",
    "            return tf.nn.rnn_cell.GRUCell(size_layer,reuse=reuse)\n",
    "\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "\n",
    "        # Lengths are inferred from the data: every non-zero id counts as a real token.\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "\n",
    "        embeddings = tf.Variable(tf.random_uniform([vocab_size, embedded_size], -1, 1))\n",
    "\n",
    "        # Encoder: only the final state is kept; it seeds the decoder.\n",
    "        _, encoder_state = tf.nn.dynamic_rnn(\n",
    "            cell = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)]),\n",
    "            inputs = tf.nn.embedding_lookup(embeddings, self.X),\n",
    "            sequence_length = self.X_seq_len,\n",
    "            dtype = tf.float32)\n",
    "\n",
    "        # Teacher forcing input: drop the last target column and prepend GO.\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        dense = tf.layers.Dense(vocab_size)\n",
    "        # The same cell stack and projection layer objects are handed to both decoders below.\n",
    "        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)])\n",
    "\n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                time_major = False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = training_helper,\n",
    "                initial_state = encoder_state,\n",
    "                output_layer = dense)\n",
    "        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "\n",
    "        # Greedy inference: start from GO, stop at EOS or after 2x the longest source length.\n",
    "        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "                embedding = embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS)\n",
    "        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = predicting_helper,\n",
    "                initial_state = encoder_state,\n",
    "                output_layer = dense)\n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "        self.fast_result = predicting_decoder_output.sample_id\n",
    "\n",
    "        # tf.boolean_mask is documented to take a boolean mask, so the mask is built as\n",
    "        # bool and a float32 copy is used for the loss weights.\n",
    "        bool_masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len))\n",
    "        masks = tf.cast(bool_masks, tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "\n",
    "        # Token accuracy, restricted to the non-padded target positions.\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, bool_masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, bool_masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 512\n",
    "num_layers = 2\n",
    "embedded_size = 256\n",
    "learning_rate = 1e-3\n",
    "batch_size = 128\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n",
      "WARNING:tensorflow:From <ipython-input-7-e07e70632a49>:11: GRUCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.GRUCell, and will be replaced by that in Tensorflow 2.0.\n",
      "WARNING:tensorflow:From <ipython-input-7-e07e70632a49>:23: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.\n",
      "WARNING:tensorflow:From <ipython-input-7-e07e70632a49>:26: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:559: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:565: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:575: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn.py:244: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Start from an empty default graph, then build the model and initialise its variables.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Translator(\n",
    "    size_layer = size_layer,\n",
    "    num_layers = num_layers,\n",
    "    embedded_size = embedded_size,\n",
    "    learning_rate = learning_rate)\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess.run(init_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short alias for the Keras padding helper; the cells below call it with padding='post'.\n",
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[ 9474,  9474,  8026,  8026, 11354, 11354, 30856, 19043, 25768,\n",
       "         25768,  5578, 30394, 30394, 30394, 30394, 30394, 31105, 31105,\n",
       "          5803,  1253,  1253,  1253, 29253, 29253, 29253, 29253, 29253,\n",
       "         10336, 10336,  6833,  6833,  5621,  5621,  5621,  5621,  5621,\n",
       "          5621,  5621, 14544, 14544,  7972,  7972,  7972,  7972, 19354,\n",
       "         19354, 19354, 19354, 22886, 28105,  7491,  7491,  7491,  7491,\n",
       "          2215,  2215,  1830,  1830,  1830, 27806,  1664,  1664,  1664,\n",
       "          7601,  7601,  7601,  6772,  6772, 31802, 31802, 31802, 26975],\n",
       "        [19758,  2395, 11639, 11639,  3554, 14988, 14988, 16578, 16578,\n",
       "         28519, 28519, 29929, 29929,  8231,  8231, 19925,  3138,  3138,\n",
       "         26914, 26338, 27344, 27344, 27344, 27344, 27344, 13816, 15748,\n",
       "         15748, 15748, 16686, 16095, 16095, 16095, 16095,  4248, 24085,\n",
       "         24085,  5160, 24085, 24085,  5866,  6372,  6372,  6372,  6372,\n",
       "           935, 26480, 26480, 26480, 26480,  2037, 21200,   155,   155,\n",
       "         13017, 13017, 13017,  6820,   233,  6820,  6820, 31387, 14536,\n",
       "         10176, 10176, 10176, 10176, 30230, 31099, 30230, 31099, 25218],\n",
       "        [ 4447,  4447,   417, 25156, 25156, 25156, 25156, 25156, 25156,\n",
       "          2756, 25156,  3504, 25957, 25957,  4286, 23244, 23244, 23244,\n",
       "          1036,   638, 22027, 25615, 25615, 15029, 25615,  2538, 20524,\n",
       "           883,   883,  4898,  4898,  4898,  6984, 30983, 16765, 16765,\n",
       "         16765, 16765, 28003,  7360,  7360, 23123, 23123, 18384, 27342,\n",
       "         27342, 27342, 20580, 20580, 20580, 13977, 13977, 13977, 13977,\n",
       "         28319, 28319, 17141, 17141, 17141, 15452,  9311,  9311,  9311,\n",
       "         24181, 24181, 25791, 25791,  7274,  7274,  4516, 10100, 10100],\n",
       "        [26050, 16590, 16590, 16590, 16590,  5385,  5385,  5385,   172,\n",
       "         15661, 15661, 17136, 22151, 22151, 27348, 18651,  6074,  6074,\n",
       "          6074,  6074,  6074,  6074, 12982, 12982, 12982,  6521, 29972,\n",
       "         29972,  6511,  7819,  7819,  2609, 14359, 14359, 14359,  3705,\n",
       "          9541, 25303,  9541, 25303,  4792,  2979, 17462, 17462, 17462,\n",
       "         17462, 17462,  5527, 15841, 15841, 15396, 15396, 21546, 16237,\n",
       "         16237, 16237, 16237,  2799,  2799,  2799, 13780, 23869, 23869,\n",
       "         24792, 24792, 13022, 13022, 13022,  9865,  7696, 22248, 31606],\n",
       "        [19979, 20579, 22020, 22020,   576, 19948, 19948, 10791,  5248,\n",
       "          5248,  5248,  1185,  1185, 20392, 20392, 20392, 20392, 20392,\n",
       "         31140, 31140, 18341, 18891, 18891, 18891, 18891, 18891, 31792,\n",
       "         31792, 31792, 31792, 31792,  5547,  5547,  5547,  5060,  5060,\n",
       "          5060, 18725, 18725, 18725,     3,  5237,  5237, 27163, 27163,\n",
       "         27163, 27163, 27163, 29317,  9957,  9957,  9957, 17969, 17969,\n",
       "         17969, 17969, 15075,  1140,  1140, 30802, 30802, 31919, 31919,\n",
       "         31919, 31919, 12139,  6761,  2441,  2441,   382,   382,   382],\n",
       "        [14415,  9829,  1058,  1058, 23490, 23490, 23490, 23490, 23490,\n",
       "         23490,  6579,  6579, 13956, 13956, 17003, 24228, 24228, 24228,\n",
       "         24228, 11307, 11307,  9794,  9794,  9794, 24817, 24817, 16903,\n",
       "         14146, 14146, 14146, 23557, 23557, 23557, 29409, 18324, 16263,\n",
       "          7605,  7605, 20814, 20814, 20814, 20814, 20814,  5027,  5027,\n",
       "          5027,  5027,  5027, 18909,   893,  8683,  8109,  8109,  8109,\n",
       "         27426, 27426, 27426, 27426, 13855, 13855, 13855, 13855, 31722,\n",
       "         31722, 31722, 31722, 25075, 25075, 22512, 20703, 20703, 20703],\n",
       "        [22911, 18466, 18466, 12076,  4050,  4050,  4050, 19813, 19813,\n",
       "         19813, 19813,  2865,  2865,  2865, 14321, 14321, 14321, 14321,\n",
       "         14321, 14321,  4438,  4438,  4438,  4438,  8582,  8582,  8582,\n",
       "          8582, 26578, 11339, 11339, 11339, 11339, 11339, 11339, 11339,\n",
       "         11339, 13361, 13361, 26873, 26873, 26873, 22957, 22957, 13748,\n",
       "         13748, 13748,  5623,  5623,  5623, 17672, 17672, 17672, 15028,\n",
       "         15028, 15028,  8560,  8560, 30564, 29567, 29567, 30564, 10013,\n",
       "         10013, 10013, 10013, 10013, 10013,  3026,  3026,  3026,  3026],\n",
       "        [10577, 29066, 29066, 25428, 13980, 26499, 25428, 12203, 12203,\n",
       "         10868, 10868, 27443, 27443,   370,   370, 12663, 12663, 26829,\n",
       "          2433,  2433,  2433,  2916, 16306, 16306, 31577,  4770, 18127,\n",
       "         18127, 18127,  8767, 21902, 21902, 21902,  3602, 10718, 10718,\n",
       "         10718, 10718,  6790, 10718, 17278, 25493, 14993, 14993, 14993,\n",
       "         14993,  3650, 18070, 18070,  1424,  1424,  3362,  3362,  3362,\n",
       "         19308, 29145, 27664, 14634, 14634, 18000, 25387, 29952, 29952,\n",
       "         29952, 29952, 23214, 23214,  1483, 20303, 20303, 28586, 28586],\n",
       "        [29893, 29893, 10716, 24337, 24337, 24259, 24259, 24259, 29739,\n",
       "         29739,  6923,  6923,  2749,  2749,  2749,  8470,  8470,  3967,\n",
       "          3967,  3967,  3967,  3967,  3967, 19957, 19957, 19957, 14301,\n",
       "         14301, 14562, 14562,  9031, 18729, 18729, 18729, 18729, 18729,\n",
       "         18729, 22677, 22677,  4087,  4087,  4087, 30576,  4323,  4323,\n",
       "          4323,  4323,  3100,  3100,  3100,  3100, 13917, 13917, 15060,\n",
       "         15060, 15060, 11862, 11862, 25957,  3450,  3450,  3450,  3450,\n",
       "         21986, 21986, 21986, 29361, 24712, 26797, 28380, 28380, 28380],\n",
       "        [30362, 24507,  3881,  3881,   645,   645,   645,   645,   645,\n",
       "           645,  5768,  5768,  5768,  5768, 29522, 25843, 25843, 25843,\n",
       "         25843, 26413, 26413, 26413, 26413, 26413, 26413, 19608, 19608,\n",
       "          8723,  8723,  8723,  8723,  8723, 19173, 19173, 19173, 26717,\n",
       "         26717, 30606, 30606,  1418, 29528, 12212, 12212, 12212, 10802,\n",
       "          2835,  2835,  2835, 20336,  9931, 28285, 28285, 28285, 22681,\n",
       "         22681, 16125, 24028, 24028, 24028, 24028, 27922, 27922,  2786,\n",
       "          2786, 10073, 10073, 10073, 10073, 20427, 20427, 20427, 20427]],\n",
       "       dtype=int32), 10.371124, 0.0]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: run greedy decoding, loss and accuracy on the first 10 training pairs\n",
    "# with the freshly initialised weights.\n",
    "batch_x = pad_sequences(train_X[:10], padding='post')\n",
    "batch_y = pad_sequences(train_Y[:10], padding='post')\n",
    "\n",
    "smoke_fetches = [model.fast_result, model.cost, model.accuracy]\n",
    "smoke_feed = {model.X: batch_x, model.Y: batch_y}\n",
    "sess.run(smoke_fetches, feed_dict = smoke_feed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:43<00:00,  1.66it/s, accuracy=0.258, cost=4.66]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.94it/s, accuracy=0.328, cost=4.03]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, training avg loss 5.760781, training avg acc 0.176343\n",
      "epoch 1, testing avg loss 4.482817, testing avg acc 0.264885\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:33<00:00,  1.67it/s, accuracy=0.34, cost=3.73] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:11<00:00,  3.61it/s, accuracy=0.344, cost=3.47]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, training avg loss 4.051785, training avg acc 0.305687\n",
      "epoch 2, testing avg loss 3.847744, testing avg acc 0.329494\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:53<00:00,  1.64it/s, accuracy=0.408, cost=3.05]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.01it/s, accuracy=0.376, cost=3.28]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, training avg loss 3.436246, training avg acc 0.368477\n",
      "epoch 3, testing avg loss 3.616866, testing avg acc 0.357988\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:26<00:00,  1.69it/s, accuracy=0.476, cost=2.51]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.97it/s, accuracy=0.382, cost=3.32]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, training avg loss 3.047151, training avg acc 0.412772\n",
      "epoch 4, testing avg loss 3.551020, testing avg acc 0.367325\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:27<00:00,  1.69it/s, accuracy=0.549, cost=2.13]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.99it/s, accuracy=0.409, cost=3.27]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, training avg loss 2.761674, training avg acc 0.448493\n",
      "epoch 5, testing avg loss 3.559103, testing avg acc 0.369203\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:56<00:00,  1.63it/s, accuracy=0.608, cost=1.82]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.01it/s, accuracy=0.425, cost=3.3] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, training avg loss 2.536371, training avg acc 0.479370\n",
      "epoch 6, testing avg loss 3.606942, testing avg acc 0.367298\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:30<00:00,  1.68it/s, accuracy=0.679, cost=1.52]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.06it/s, accuracy=0.382, cost=3.45]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, training avg loss 2.353640, training avg acc 0.505964\n",
      "epoch 7, testing avg loss 3.691811, testing avg acc 0.361439\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:27<00:00,  1.69it/s, accuracy=0.687, cost=1.39]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.01it/s, accuracy=0.409, cost=3.38]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, training avg loss 2.203858, training avg acc 0.528740\n",
      "epoch 8, testing avg loss 3.778690, testing avg acc 0.358032\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:30<00:00,  1.68it/s, accuracy=0.724, cost=1.25]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.00it/s, accuracy=0.419, cost=3.45]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, training avg loss 2.080045, training avg acc 0.547421\n",
      "epoch 9, testing avg loss 3.848859, testing avg acc 0.358323\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:39<00:00,  1.66it/s, accuracy=0.744, cost=1.1] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.94it/s, accuracy=0.398, cost=3.55]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, training avg loss 1.969356, training avg acc 0.564852\n",
      "epoch 10, testing avg loss 3.939119, testing avg acc 0.354124\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:29<00:00,  1.68it/s, accuracy=0.763, cost=1.04]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.00it/s, accuracy=0.398, cost=3.69]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11, training avg loss 1.878914, training avg acc 0.578619\n",
      "epoch 11, testing avg loss 4.053656, testing avg acc 0.348667\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:28<00:00,  1.68it/s, accuracy=0.778, cost=0.95]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.01it/s, accuracy=0.382, cost=3.72]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12, training avg loss 1.805559, training avg acc 0.589485\n",
      "epoch 12, testing avg loss 4.125336, testing avg acc 0.345776\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:28<00:00,  1.68it/s, accuracy=0.781, cost=0.927]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.01it/s, accuracy=0.387, cost=3.84]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13, training avg loss 1.728404, training avg acc 0.602297\n",
      "epoch 13, testing avg loss 4.207379, testing avg acc 0.342252\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:29<00:00,  1.68it/s, accuracy=0.784, cost=0.882]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.97it/s, accuracy=0.382, cost=4]   \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14, training avg loss 1.660186, training avg acc 0.613397\n",
      "epoch 14, testing avg loss 4.312093, testing avg acc 0.340908\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:26<00:00,  1.69it/s, accuracy=0.803, cost=0.819]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.01it/s, accuracy=0.387, cost=4.01]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15, training avg loss 1.607585, training avg acc 0.621560\n",
      "epoch 15, testing avg loss 4.369632, testing avg acc 0.343075\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:28<00:00,  1.68it/s, accuracy=0.811, cost=0.773]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.04it/s, accuracy=0.419, cost=3.89]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16, training avg loss 1.562124, training avg acc 0.628399\n",
      "epoch 16, testing avg loss 4.445291, testing avg acc 0.341184\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:31<00:00,  1.68it/s, accuracy=0.805, cost=0.771]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.01it/s, accuracy=0.371, cost=4.04]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17, training avg loss 1.516494, training avg acc 0.635623\n",
      "epoch 17, testing avg loss 4.519156, testing avg acc 0.337541\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:43<00:00,  1.66it/s, accuracy=0.811, cost=0.756]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.00it/s, accuracy=0.355, cost=4.28]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18, training avg loss 1.472787, training avg acc 0.642899\n",
      "epoch 18, testing avg loss 4.615767, testing avg acc 0.332683\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:29<00:00,  1.68it/s, accuracy=0.827, cost=0.694]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.03it/s, accuracy=0.349, cost=4.27]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19, training avg loss 1.434646, training avg acc 0.649070\n",
      "epoch 19, testing avg loss 4.687762, testing avg acc 0.331642\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:43<00:00,  1.66it/s, accuracy=0.816, cost=0.742]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  3.02it/s, accuracy=0.376, cost=4.27]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, training avg loss 1.405794, training avg acc 0.653162\n",
      "epoch 20, testing avg loss 4.746929, testing avg acc 0.333465\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "for e in range(epoch):\n",
    "    # Per-epoch metric buffers; one entry is appended per minibatch.\n",
    "    train_loss, train_acc, test_loss, test_acc = [], [], [], []\n",
    "\n",
    "    # Training pass: one optimizer step per minibatch, visiting the data in stored order.\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'minibatch loop')\n",
    "    for start in pbar:\n",
    "        stop = min(start + batch_size, len(train_X))\n",
    "        batch_x = pad_sequences(train_X[start : stop], padding='post')\n",
    "        batch_y = pad_sequences(train_Y[start : stop], padding='post')\n",
    "        feed = {model.X: batch_x, model.Y: batch_y}\n",
    "        accuracy, loss, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer], feed_dict = feed)\n",
    "        train_loss.append(loss)\n",
    "        train_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "\n",
    "    # Evaluation pass: same metrics on the test split, without running the optimizer.\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'minibatch loop')\n",
    "    for start in pbar:\n",
    "        stop = min(start + batch_size, len(test_X))\n",
    "        batch_x = pad_sequences(test_X[start : stop], padding='post')\n",
    "        batch_y = pad_sequences(test_Y[start : stop], padding='post')\n",
    "        feed = {model.X: batch_x, model.Y: batch_y}\n",
    "        accuracy, loss = sess.run(\n",
    "            [model.accuracy, model.cost], feed_dict = feed)\n",
    "        test_loss.append(loss)\n",
    "        test_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "\n",
    "    # Epoch summary: unweighted means of the per-minibatch values.\n",
    "    train_summary = (e + 1, np.mean(train_loss), np.mean(train_acc))\n",
    "    test_summary = (e + 1, np.mean(test_loss), np.mean(test_acc))\n",
    "    print('epoch %d, training avg loss %f, training avg acc %f' % train_summary)\n",
    "    print('epoch %d, testing avg loss %f, testing avg acc %f' % test_summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 40/40 [00:14<00:00,  2.70it/s]\n"
     ]
    }
   ],
   "source": [
    "from tensor2tensor.utils import bleu_hook\n",
    "\n",
    "# Greedy-decode the test sources in minibatches and collect one id list per sentence.\n",
    "results = []\n",
    "for start in tqdm.tqdm(range(0, len(test_X), batch_size)):\n",
    "    stop = min(start + batch_size, len(test_X))\n",
    "    batch_x = pad_sequences(test_X[start : stop], padding='post')\n",
    "    decoded = sess.run(model.fast_result, feed_dict = {model.X: batch_x})\n",
    "    # Keep only ids above 3; ids 0-3 look like reserved tokens (0 pad, 1 GO, 2 EOS) - TODO confirm id 3.\n",
    "    results.extend([tok for tok in row if tok > 3] for row in decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference translations, filtered with the same `> 3` rule that is applied to the decoded output.\n",
    "rights = [[tok for tok in ref if tok > 3] for ref in test_Y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.051461186"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the greedy translations against the references; both are lists of filtered token-id lists.\n",
    "bleu_hook.compute_bleu(reference_corpus = rights,\n",
    "                       translation_corpus = results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
